{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f896245e-57c4-48fd-854f-9e43f22e10c9",
   "metadata": {},
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "<br>Chinese translation repository: <a href=\"https://github.com/GoatCsu/CN-LLMs-from-scratch.git\">https://github.com/GoatCsu/CN-LLMs-from-scratch.git</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca7fc8a0-280c-4979-b0c7-fc3a99b3b785",
   "metadata": {},
   "source": [
    "# Appendix A: Introduction to PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5bf13d2-8fc2-483e-88cc-6b4310221e68",
   "metadata": {},
   "source": [
    "## A.1 What is PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "96ee5660-5327-48e2-9104-a882b3b2afa4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.4.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f73ad4e4-7ec6-4467-a9e9-0cdf6d195264",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n"
     ]
    }
   ],
   "source": [
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "397ba1ab-3306-4965-8618-1ed5f24fb939",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A1/1.png\" width=\"400px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e3c0555-88f6-4515-8c99-aa56b0769d54",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A1/2.png\" width=\"300px\">\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A1/3.png\" width=\"300px\">\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A1/4.png\" width=\"500px\">\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A1/5.png\" width=\"500px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2100cf2e-7459-4ab3-92a8-43e86ab35a9b",
   "metadata": {},
   "source": [
    "## A.2 Understanding tensors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c484e87-bfc9-4105-b0a7-1e23b2a72a30",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A2/1.png\" width=\"400px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26d7f785-e048-42bc-9182-a556af6bb7f4",
   "metadata": {},
   "source": [
    "### A.2.1 Scalars, vectors, matrices, and tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a3a464d6-cec8-4363-87bd-ea4f900baced",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "# Create a 0D tensor (scalar) from a Python integer\n",
    "tensor0d = torch.tensor(1)\n",
    "\n",
    "# Create a 1D tensor (vector) from a Python list\n",
    "tensor1d = torch.tensor([1, 2, 3])\n",
    "\n",
    "# Create a 2D tensor (matrix) from a nested Python list\n",
    "tensor2d = torch.tensor([[1, 2],\n",
    "                         [3, 4]])\n",
    "\n",
    "# Create a 3D tensor from a nested Python list\n",
    "tensor3d_1 = torch.tensor([[[1, 2], [3, 4]],\n",
    "                           [[5, 6], [7, 8]]])\n",
    "\n",
    "# Create a 3D tensor from a NumPy array\n",
    "ary3d = np.array([[[1, 2], [3, 4]],\n",
    "                  [[5, 6], [7, 8]]])\n",
    "tensor3d_2 = torch.tensor(ary3d)  # copies the NumPy array\n",
    "tensor3d_3 = torch.from_numpy(ary3d)  # shares memory with the NumPy array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dbe14c47-499a-4d48-b354-a0e6fd957872",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[1, 2],\n",
      "         [3, 4]],\n",
      "\n",
      "        [[5, 6],\n",
      "         [7, 8]]])\n"
     ]
    }
   ],
   "source": [
    "ary3d[0, 0, 0] = 999\n",
    "print(tensor3d_2)  # unchanged, because torch.tensor copied the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e3e4c23a-cdba-46f5-a2dc-5fb32bf9117b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[999,   2],\n",
      "         [  3,   4]],\n",
      "\n",
      "        [[  5,   6],\n",
      "         [  7,   8]]])\n"
     ]
    }
   ],
   "source": [
    "print(tensor3d_3)  # changed, because it shares memory with the NumPy array"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63dec48d-2b60-41a2-ac06-fef7e718605a",
   "metadata": {},
   "source": [
    "### A.2.2 Tensor data types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3f48c014-e1a2-4a53-b5c5-125812d4034c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.int64\n"
     ]
    }
   ],
   "source": [
    "tensor1d = torch.tensor([1, 2, 3])\n",
    "print(tensor1d.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5429a086-9de2-4ac7-9f14-d087a7507394",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n"
     ]
    }
   ],
   "source": [
    "floatvec = torch.tensor([1.0, 2.0, 3.0])\n",
    "print(floatvec.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a9a438d1-49bb-481c-8442-7cc2bb3dd4af",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n"
     ]
    }
   ],
   "source": [
    "floatvec = tensor1d.to(torch.float32)\n",
    "print(floatvec.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2020deb5-aa02-4524-b311-c010f4ad27ff",
   "metadata": {},
   "source": [
    "### A.2.3 Common PyTorch tensor operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c02095f2-8a48-4953-b3c9-5313d4362ce7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3],\n",
       "        [4, 5, 6]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2d = torch.tensor([[1, 2, 3], \n",
    "                         [4, 5, 6]])\n",
    "tensor2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f33e1d45-5b2c-4afe-b4b2-66ac4099fd1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2d.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f3a4129d-f870-4e03-9c32-cd8521cb83fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2],\n",
       "        [3, 4],\n",
       "        [5, 6]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2d.reshape(3, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "589ac0a7-adc7-41f3-b721-155f580e9369",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2],\n",
       "        [3, 4],\n",
       "        [5, 6]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2d.view(3, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "344e307f-ba5d-4f9a-a791-2c75a3d1417e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 4],\n",
       "        [2, 5],\n",
       "        [3, 6]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2d.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "19a75030-6a41-4ca8-9aae-c507ae79225c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[14, 32],\n",
       "        [32, 77]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2d.matmul(tensor2d.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e7c950bc-d640-4203-b210-3ac8932fe4d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[14, 32],\n",
       "        [32, 77]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2d @ tensor2d.T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c15bdeb-78e2-4870-8a4f-a9f591666f38",
   "metadata": {},
   "source": [
    "## A.3 Seeing models as computation graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f3e16c3-07df-44b6-9106-a42fb24452a9",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A3/1.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "22af61e9-0443-4705-94d7-24c21add09c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0852)\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "y = torch.tensor([1.0])   # true label\n",
    "x1 = torch.tensor([1.1])  # input feature\n",
    "w1 = torch.tensor([2.2])  # weight parameter\n",
    "b = torch.tensor([0.0])   # bias unit\n",
    "\n",
    "z = x1 * w1 + b           # net input\n",
    "a = torch.sigmoid(z)      # activation and output\n",
    "\n",
    "loss = F.binary_cross_entropy(a, y)\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9424f26-2bac-47e7-b834-92ece802247c",
   "metadata": {},
   "source": [
    "## A.4 Automatic differentiation made easy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33aa2ee4-6f1d-448d-8707-67cd5278233c",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A4/1.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ebf5cef7-48d6-4d2a-8ab0-0fb10bdd7d1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([-0.0898]),)\n",
      "(tensor([-0.0817]),)\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "from torch.autograd import grad\n",
    "\n",
    "y = torch.tensor([1.0])\n",
    "x1 = torch.tensor([1.1])\n",
    "w1 = torch.tensor([2.2], requires_grad=True)\n",
    "b = torch.tensor([0.0], requires_grad=True)\n",
    "\n",
    "z = x1 * w1 + b \n",
    "a = torch.sigmoid(z)\n",
    "\n",
    "loss = F.binary_cross_entropy(a, y)\n",
    "\n",
    "grad_L_w1 = grad(loss, w1, retain_graph=True)\n",
    "grad_L_b = grad(loss, b, retain_graph=True)\n",
    "\n",
    "print(grad_L_w1)\n",
    "print(grad_L_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "93c5875d-f6b2-492c-b5ef-7e132f93a4e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.0898])\n",
      "tensor([-0.0817])\n"
     ]
    }
   ],
   "source": [
    "loss.backward()\n",
    "\n",
    "print(w1.grad)\n",
    "print(b.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f53bdd7d-44e6-40ab-8a5a-4eef74ef35dc",
   "metadata": {},
   "source": [
    "## A.5 Implementing multilayer neural networks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6cb9787-2bc8-4379-9e8c-a3401ac63c51",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A5/1.png\" width=\"500px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "84b749e1-7768-4cfe-94d6-a08c7feff4a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeuralNetwork(torch.nn.Module):\n",
    "    def __init__(self, num_inputs, num_outputs):\n",
    "        super().__init__()\n",
    "\n",
    "        self.layers = torch.nn.Sequential(\n",
    "                \n",
    "            # 1st hidden layer\n",
    "            torch.nn.Linear(num_inputs, 30),\n",
    "            torch.nn.ReLU(),\n",
    "\n",
    "            # 2nd hidden layer\n",
    "            torch.nn.Linear(30, 20),\n",
    "            torch.nn.ReLU(),\n",
    "\n",
    "            # output layer\n",
    "            torch.nn.Linear(20, num_outputs),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.layers(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c5b59e2e-1930-456d-93b9-f69263e3adbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = NeuralNetwork(50, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "39d02a21-33e7-4879-8fd2-d6309faf2f8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NeuralNetwork(\n",
      "  (layers): Sequential(\n",
      "    (0): Linear(in_features=50, out_features=30, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=30, out_features=20, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=20, out_features=3, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "94535738-de02-4c2a-9b44-1cd186fa990a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trainable model parameters: 2213\n"
     ]
    }
   ],
   "source": [
    "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(\"Total number of trainable model parameters:\", num_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2c394106-ad71-4ccb-a3c9-9b60af3fa748",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[ 0.1182,  0.0606, -0.1292,  ..., -0.1126,  0.0735, -0.0597],\n",
      "        [-0.0249,  0.0154, -0.0476,  ..., -0.1001, -0.1288,  0.1295],\n",
      "        [ 0.0641,  0.0018, -0.0367,  ..., -0.0990, -0.0424, -0.0043],\n",
      "        ...,\n",
      "        [ 0.0618,  0.0867,  0.1361,  ..., -0.0254,  0.0399,  0.1006],\n",
      "        [ 0.0842, -0.0512, -0.0960,  ..., -0.1091,  0.1242, -0.0428],\n",
      "        [ 0.0518, -0.1390, -0.0923,  ..., -0.0954, -0.0668, -0.0037]],\n",
      "       requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "print(model.layers[0].weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b201882b-9285-4db9-bb63-43afe6a2ff9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.0577,  0.0047, -0.0702,  ...,  0.0222,  0.1260,  0.0865],\n",
      "        [ 0.0502,  0.0307,  0.0333,  ...,  0.0951,  0.1134, -0.0297],\n",
      "        [ 0.1077, -0.1108,  0.0122,  ...,  0.0108, -0.1049, -0.1063],\n",
      "        ...,\n",
      "        [-0.0787,  0.1259,  0.0803,  ...,  0.1218,  0.1303, -0.1351],\n",
      "        [ 0.1359,  0.0175, -0.0673,  ...,  0.0674,  0.0676,  0.1058],\n",
      "        [ 0.0790,  0.1343, -0.0293,  ...,  0.0344, -0.0971, -0.0509]],\n",
      "       requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "model = NeuralNetwork(50, 3)\n",
    "print(model.layers[0].weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1da9a35e-44f3-460c-90fe-304519736fd6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([30, 50])\n"
     ]
    }
   ],
   "source": [
    "print(model.layers[0].weight.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "57eadbae-90fe-43a3-a33f-c23a095ba42a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.1262,  0.1080, -0.1792]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "X = torch.rand((1, 50))\n",
    "out = model(X)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "48d720cb-ef73-4b7b-92e0-8198a072defd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.1262,  0.1080, -0.1792]])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    out = model(X)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "10df3640-83c3-4061-a74d-08f07a5cc6ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.3113, 0.3934, 0.2952]])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    out = torch.softmax(model(X), dim=1)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19858180-0f26-43a8-b2c3-7ed40abf9f85",
   "metadata": {},
   "source": [
    "## A.6 Setting up efficient data loaders"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f98d8fc-5618-47a2-bc72-153818972a24",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A6/1.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "b9dc2745-8be8-4344-80ef-325f02cda7b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = torch.tensor([\n",
    "    [-1.2, 3.1],\n",
    "    [-0.9, 2.9],\n",
    "    [-0.5, 2.6],\n",
    "    [2.3, -1.1],\n",
    "    [2.7, -1.5]\n",
    "])\n",
    "\n",
    "y_train = torch.tensor([0, 0, 0, 1, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "88283948-5fca-461a-98a1-788b6be191d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = torch.tensor([\n",
    "    [-0.8, 2.8],\n",
    "    [2.6, -1.6],\n",
    "])\n",
    "\n",
    "y_test = torch.tensor([0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "edf323e2-1789-41a0-8e44-f3cab16e5f5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class ToyDataset(Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        self.features = X\n",
    "        self.labels = y\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        one_x = self.features[index]\n",
    "        one_y = self.labels[index]        \n",
    "        return one_x, one_y\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.labels.shape[0]\n",
    "\n",
    "train_ds = ToyDataset(X_train, y_train)\n",
    "test_ds = ToyDataset(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "b7014705-1fdc-4f72-b892-d8db8bebc331",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "3ec6627a-4c3f-481a-b794-d2131be95eaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    dataset=train_ds,\n",
    "    batch_size=2,\n",
    "    shuffle=True,\n",
    "    num_workers=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "8c9446de-5e4b-44fa-bf9a-a63e2661027e",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ds = ToyDataset(X_test, y_test)\n",
    "\n",
    "test_loader = DataLoader(\n",
    "    dataset=test_ds,\n",
    "    batch_size=2,\n",
    "    shuffle=False,\n",
    "    num_workers=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "99d4404c-9884-419f-979c-f659742d86ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1: tensor([[ 2.3000, -1.1000],\n",
      "        [-0.9000,  2.9000]]) tensor([1, 0])\n",
      "Batch 2: tensor([[-1.2000,  3.1000],\n",
      "        [-0.5000,  2.6000]]) tensor([0, 0])\n",
      "Batch 3: tensor([[ 2.7000, -1.5000]]) tensor([1])\n"
     ]
    }
   ],
   "source": [
    "for idx, (x, y) in enumerate(train_loader):\n",
    "    print(f\"Batch {idx+1}:\", x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "9d003f7e-7a80-40bf-a7fb-7a0d7dbba9db",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(\n",
    "    dataset=train_ds,\n",
    "    batch_size=2,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    "    drop_last=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "4db4d7f4-82da-44a4-b94e-ee04665d9c3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1: tensor([[-1.2000,  3.1000],\n",
      "        [-0.5000,  2.6000]]) tensor([0, 0])\n",
      "Batch 2: tensor([[ 2.3000, -1.1000],\n",
      "        [-0.9000,  2.9000]]) tensor([1, 0])\n"
     ]
    }
   ],
   "source": [
    "for idx, (x, y) in enumerate(train_loader):\n",
    "    print(f\"Batch {idx+1}:\", x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb03ed57-df38-4ee0-a553-0863450df39b",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/MLNLP-World/LLMs-from-scratch-CN/main/imgs/A6/2.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d904ca82-e50f-4f3d-a3ac-fc6ca53dd00e",
   "metadata": {},
   "source": [
    "## A.7 A typical training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "93f1791a-d887-4fc5-a307-5e5bde9e06f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/003 | Batch 000/002 | Train/Val Loss: 0.75\n",
      "Epoch: 001/003 | Batch 001/002 | Train/Val Loss: 0.65\n",
      "Epoch: 002/003 | Batch 000/002 | Train/Val Loss: 0.44\n",
      "Epoch: 002/003 | Batch 001/002 | Train/Val Loss: 0.13\n",
      "Epoch: 003/003 | Batch 000/002 | Train/Val Loss: 0.03\n",
      "Epoch: 003/003 | Batch 001/002 | Train/Val Loss: 0.00\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "model = NeuralNetwork(num_inputs=2, num_outputs=2)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.5)\n",
    "\n",
    "num_epochs = 3\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    model.train()\n",
    "    for batch_idx, (features, labels) in enumerate(train_loader):\n",
    "\n",
    "        logits = model(features)\n",
    "        \n",
    "        loss = F.cross_entropy(logits, labels)  # loss function\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "        ### Logging\n",
    "        print(f\"Epoch: {epoch+1:03d}/{num_epochs:03d}\"\n",
    "              f\" | Batch {batch_idx:03d}/{len(train_loader):03d}\"\n",
    "              f\" | Train/Val Loss: {loss:.2f}\")\n",
    "\n",
    "    model.eval()\n",
    "    # Insert optional model evaluation code here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "00dcf57f-6a7e-4af7-aa5a-df2cb0866fa5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2.8569, -4.1618],\n",
      "        [ 2.5382, -3.7548],\n",
      "        [ 2.0944, -3.1820],\n",
      "        [-1.4814,  1.4816],\n",
      "        [-1.7176,  1.7342]])\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(X_train)\n",
    "\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "19be7390-18b8-43f9-9841-d7fb1919f6fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[    0.9991,     0.0009],\n",
      "        [    0.9982,     0.0018],\n",
      "        [    0.9949,     0.0051],\n",
      "        [    0.0491,     0.9509],\n",
      "        [    0.0307,     0.9693]])\n",
      "tensor([0, 0, 0, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "torch.set_printoptions(sci_mode=False)\n",
    "probas = torch.softmax(outputs, dim=1)\n",
    "print(probas)\n",
    "\n",
    "predictions = torch.argmax(probas, dim=1)\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "07e7e530-f8d3-429c-9f5e-cf8078078c0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0, 0, 0, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "predictions = torch.argmax(outputs, dim=1)\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "5f756f0d-63c8-41b5-a5d8-01baa847e026",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([True, True, True, True, True])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions == y_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "da274bb0-f11c-4c81-a880-7a031fbf2943",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(5)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sum(predictions == y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "16d62314-8dee-45b0-8f55-9e5aae2b24f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(model, dataloader):\n",
    "\n",
    "    model = model.eval()\n",
    "    correct = 0.0\n",
    "    total_examples = 0\n",
    "    \n",
    "    for idx, (features, labels) in enumerate(dataloader):\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            logits = model(features)\n",
    "        \n",
    "        predictions = torch.argmax(logits, dim=1)\n",
    "        compare = labels == predictions\n",
    "        correct += torch.sum(compare)\n",
    "        total_examples += len(compare)\n",
    "\n",
    "    return (correct / total_examples).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "4f6c9c17-2a5f-46c0-804b-873f169b729a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compute_accuracy(model, train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "311ed864-e21e-4aac-97c7-c6086caef27a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compute_accuracy(model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d5cd469-3a45-4394-944b-3ce543f41dac",
   "metadata": {},
   "source": [
    "## A.8 Saving and loading models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b013127d-a2c3-4b04-9fb3-a6a7c88d83c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), \"model.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "b2b428c2-3a44-4d91-97c4-8298cf2b51eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = NeuralNetwork(2, 2)  # architecture must exactly match the originally saved model\n",
    "model.load_state_dict(torch.load(\"model.pth\", weights_only=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f891c013-43da-4a05-973d-997be313d2d8",
   "metadata": {},
   "source": [
    "## A.9 Optimizing training performance with GPUs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e68ae888-cabf-49c9-bad6-ecdce774db57",
   "metadata": {},
   "source": [
    "### A.9.1 PyTorch computations on GPU devices"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "141c845f-efe3-4614-b376-b8b7a9a2c887",
   "metadata": {},
   "source": [
    "See [code-part2.ipynb](code-part2.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99811829-b817-42ea-b03e-d35374debcc0",
   "metadata": {},
   "source": [
    "### A.9.2 Single-GPU training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b21456c-4af7-440f-9e78-37770277b5bc",
   "metadata": {},
   "source": [
    "See [code-part2.ipynb](code-part2.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db6eb2d1-a341-4489-b04b-635c26945333",
   "metadata": {},
   "source": [
    "### A.9.3 Training with multiple GPUs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d049a81-5fb0-49b5-9d6a-17a9976d8520",
   "metadata": {},
   "source": [
    "See [DDP-script.py](DDP-script.py)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
